import numpy as np

from references.polar.sequence import Sequence


# noinspection PyPep8Naming
class Encoder:
    """Encoder for polar codes.

    A simple matrix implementation of polar encoder. The class is instantiated for
    a given codeword and message size to avoid repeated computation of the generator
    matrix. It uses polar sequence defined in the sequence module.
    """

    def __init__(self, N, K):
        """Create an instance of polar encoder.

        The instance for given codeword and message size is created.

        Args:
            N: Codeword size; a power of two in range [4, 1024].
            K: Message size; in range [0, N].

        Raises:
            ValueError: At least one argument is invalid.
        """
        if N > 1024 or N < 4:
            raise ValueError('Codeword size has to be in range [4, 1024]!')

        # The generator matrix is a Kronecker power of the 2x2 kernel, so it only has
        # N rows when N == 2**n. Without this check a size such as 5 would silently
        # produce a 4x4 matrix and only fail later inside polar_transform().
        if 2 ** int(np.log2(N)) != N:
            raise ValueError('Codeword size has to be a power of two!')

        if K > N:
            raise ValueError('Message size cannot exceed codeword size!')

        if K < 0:
            raise ValueError('Message size cannot be negative!')

        self._N = N
        self._K = K

        self._polar_sequence = Sequence(self._N, self._K)
        self._G = self._generator_matrix(self._N)

    @staticmethod
    def _generator_matrix(N):
        """Build the N x N polar generator matrix.

        The matrix is the log2(N)-fold Kronecker power of the kernel
        [[1, 0], [1, 1]]. N is expected to be a power of two (validated in
        __init__).

        Args:
            N: Codeword size.

        Returns:
            An int16 array of shape (N, N).
        """
        G_base = np.array([[1, 0], [1, 1]], dtype='int16')
        G = G_base
        # The kernel itself is the first factor, so log2(N) - 1 more products
        # are needed.
        for _ in range(int(np.log2(N)) - 1):
            G = np.kron(G, G_base)
        return G

    def polar_transform(self, c):
        """Perform polar transform on the given input to generate the polar code.

        The definition of the polar code is used to transform the given input into
        the codeword: the input row vector is multiplied by the generator matrix
        and every entry is reduced modulo 2.

        Args:
            c: Array of size N with input data.

        Returns:
            The array of size N with transform of the input.

        Raises:
            ValueError: The input argument is invalid.
        """
        if len(c) != self._N:
            raise ValueError('The codeword size has to be ' + str(self._N) + '!')

        # Reduce modulo 2 so the product corresponds to GF(2) arithmetic.
        return np.matmul(c, self._G) % 2

    def bits_insertion(self, m):
        """Map message into the codeword.

        Message bits are inserted into non-frozen positions: m[i] is written to
        the i-th entry of the sequence's unfrozen positions, and every other
        position is left at zero.

        Args:
            m: Array of message bits of size K.

        Returns:
            An int16 array of size N with inserted message bits.

        Raises:
            ValueError: The input argument is invalid.
        """
        if len(m) != self._K:
            raise ValueError('The message size has to be ' + str(self._K) + '!')

        unfrozen_bits = self._polar_sequence.unfrozen_positions
        # An explicit integer dtype keeps the K == 0 case valid: an empty index
        # array would otherwise default to float and be rejected by numpy.
        positions = np.asarray(unfrozen_bits[:self._K], dtype=np.intp)

        u = np.zeros(self._N, dtype='int16')
        u[positions] = m

        return u

    def encode(self, m):
        """Map message into the codeword and perform the polar transform.

        The codeword is generated by inserting message bits into non-frozen positions
        in the array and applying the polar transform to the result.

        Args:
            m: Array of message bits of size K.

        Returns:
            The array of size N with polar encoded message.

        Raises:
            ValueError: The input argument is invalid.
        """
        # bits_insertion() validates the message length.
        u = self.bits_insertion(m)

        return self.polar_transform(u)
